{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%install '.package(path: \"$cwd/FastaiNotebook_06_cuda\")' FastaiNotebook_06_cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import FastaiNotebook_06_cuda\n",
    "%include \"EnableIPythonDisplay.swift\"\n",
    "IPythonDisplay.shell.enable_matplotlib(\"inline\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "import Path\n",
    "import TensorFlow\n",
    "import Python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by building our own batchnorm layer from scratch. Eventually we want something like this to work:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Sketch of the batchnorm layer this notebook is working towards.\n",
    "///\n",
    "/// Normalizes the input with batch statistics while training and with the running\n",
    "/// statistics during inference, then applies `scale` and `offset`.\n",
    "class AlmostBatchNorm<Scalar: TensorFlowFloatingPoint> { // : Layer\n",
    "    // Configuration hyperparameters\n",
    "    let momentum, epsilon: Scalar\n",
    "    // Running statistics\n",
    "    var runningMean, runningVariance: Tensor<Scalar>\n",
    "    // Trainable parameters\n",
    "    var scale, offset: Tensor<Scalar>\n",
    "    \n",
    "    /// Starts with unit scale, zero offset, a running mean of 0 and a running variance of 1.\n",
    "    init(featureCount: Int, momentum: Scalar = 0.9, epsilon: Scalar = 1e-5) {\n",
    "        self.momentum = momentum\n",
    "        self.epsilon = epsilon\n",
    "        scale = Tensor(ones: [featureCount])\n",
    "        offset = Tensor(zeros: [featureCount])\n",
    "        runningMean = Tensor(0)\n",
    "        runningVariance = Tensor(1)\n",
    "    }\n",
    "\n",
    "    /// Normalizes `input` using statistics taken along axes 0, 1 and 2.\n",
    "    func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {\n",
    "        let mean: Tensor<Scalar>\n",
    "        let variance: Tensor<Scalar>\n",
    "        switch Context.local.learningPhase {\n",
    "        case .inference:\n",
    "            mean = runningMean\n",
    "            variance = runningVariance\n",
    "        case .training:\n",
    "            mean = input.mean(alongAxes: [0, 1, 2])\n",
    "            variance = input.variance(alongAxes: [0, 1, 2])\n",
    "            // Move each running statistic towards the batch statistic by a factor of (1 - momentum).\n",
    "            runningMean += (mean - runningMean) * (1 - momentum)\n",
    "            runningVariance += (variance - runningVariance) * (1 - momentum)\n",
    "        }\n",
    "        let scaledInverseStd = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * scaledInverseStd + offset\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But there are some automatic differentiation limitations (lack of support for classes and control flow) that make this impossible for now, so we'll need a few workarounds. A `Reference` will let us update running statistics without making the layer a class or declaring the `forward` method `mutating`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// A class box around a single value of type `T`.\n",
    "///\n",
    "/// Because this is a class, every copy of a `Reference` shares the same `value`.\n",
    "public class Reference<T> {\n",
    "    /// The boxed value.\n",
    "    public var value: T\n",
    "\n",
    "    /// Boxes `initialValue`.\n",
    "    public init(_ initialValue: T) {\n",
    "        value = initialValue\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following snippet will let us differentiate a layer's `forward` method (which is the one called in `call` for `FALayer`) if it's composed of training and inference implementations that are each differentiable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// A layer whose forward pass has separate training and inference implementations.\n",
    "public protocol LearningPhaseDependent: FALayer {\n",
    "    associatedtype Input\n",
    "    associatedtype Output\n",
    "\n",
    "    /// Forward pass for the `.training` learning phase.\n",
    "    @differentiable\n",
    "    func forwardTraining(_ input: Input) -> Output\n",
    "\n",
    "    /// Forward pass for the `.inference` learning phase.\n",
    "    @differentiable\n",
    "    func forwardInference(_ input: Input) -> Output\n",
    "}\n",
    "\n",
    "extension LearningPhaseDependent {\n",
    "    /// Runs `forwardTraining` or `forwardInference` depending on `Context.local.learningPhase`.\n",
    "    public func forward(_ input: Input) -> Output {\n",
    "        switch Context.local.learningPhase {\n",
    "        case .inference:\n",
    "            return forwardInference(input)\n",
    "        case .training:\n",
    "            return forwardTraining(input)\n",
    "        }\n",
    "    }\n",
    "\n",
    "    /// Custom derivative of `forward`: the value and pullback of whichever implementation\n",
    "    /// matches the current learning phase.\n",
    "    @differentiating(forward)\n",
    "    func gradForward(_ input: Input) ->\n",
    "        (value: Output, pullback: (Self.Output.TangentVector) ->\n",
    "            (Self.TangentVector, Self.Input.TangentVector)) {\n",
    "        switch Context.local.learningPhase {\n",
    "        case .inference:\n",
    "            return valueWithPullback(at: input) { layer, x in layer.forwardInference(x) }\n",
    "        case .training:\n",
    "            return valueWithPullback(at: input) { layer, x in layer.forwardTraining(x) }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can implement a BatchNorm that we can use in our models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// A layer from `TF` to `TF` that can be created from a feature count and an epsilon.\n",
    "public protocol Norm: FALayer where Input == TF, Output == TF {\n",
    "    /// Creates the norm layer for `featureCount` features with the given `epsilon`.\n",
    "    init(_ featureCount: Int, epsilon: Float)\n",
    "}\n",
    "\n",
    "/// Batch normalization whose running statistics live in `Reference` boxes, so the\n",
    "/// non-mutating forward passes can still update them.\n",
    "public struct FABatchNorm: LearningPhaseDependent, Norm {\n",
    "    // Configuration hyperparameters\n",
    "    @noDerivative var momentum, epsilon: Float\n",
    "    // Running statistics\n",
    "    @noDerivative let runningMean, runningVariance: Reference<TF>\n",
    "    // Trainable parameters\n",
    "    public var scale, offset: TF\n",
    "    \n",
    "    /// Creates a layer with unit scale, zero offset, running mean 0 and running variance 1.\n",
    "    public init(_ featureCount: Int, momentum: Float, epsilon: Float = 1e-5) {\n",
    "        scale = Tensor(ones: [featureCount])\n",
    "        offset = Tensor(zeros: [featureCount])\n",
    "        runningMean = Reference(Tensor(0))\n",
    "        runningVariance = Reference(Tensor(1))\n",
    "        self.momentum = momentum\n",
    "        self.epsilon = epsilon\n",
    "    }\n",
    "    \n",
    "    /// Same as `init(_:momentum:epsilon:)` with a momentum of 0.9.\n",
    "    public init(_ featureCount: Int, epsilon: Float = 1e-5) {\n",
    "        self.init(featureCount, momentum: 0.9, epsilon: epsilon)\n",
    "    }\n",
    "\n",
    "    /// Normalizes with batch statistics taken along axes 0, 1 and 2, and updates the\n",
    "    /// running statistics as a side effect.\n",
    "    @differentiable\n",
    "    public func forwardTraining(_ input: TF) -> TF {\n",
    "        let batchMean = input.mean(alongAxes: [0, 1, 2])\n",
    "        let batchVariance = input.variance(alongAxes: [0, 1, 2])\n",
    "        // Move each running statistic towards the batch statistic by a factor of (1 - momentum).\n",
    "        runningMean.value += (batchMean - runningMean.value) * (1 - momentum)\n",
    "        runningVariance.value += (batchVariance - runningVariance.value) * (1 - momentum)\n",
    "        let scaledInverseStd = rsqrt(batchVariance + epsilon) * scale\n",
    "        return (input - batchMean) * scaledInverseStd + offset\n",
    "    }\n",
    "    \n",
    "    /// Normalizes with the stored running statistics; nothing is updated.\n",
    "    @differentiable\n",
    "    public func forwardInference(_ input: TF) -> TF {\n",
    "        let mean = runningMean.value\n",
    "        let variance = runningVariance.value\n",
    "        let scaledInverseStd = rsqrt(variance + epsilon) * scale\n",
    "        return (input - mean) * scaledInverseStd + offset\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a generic `ConvNorm` layer that combines a conv2d layer with a norm layer (batchnorm, running batchnorm, etc.)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// A bias-free convolution with `relu` activation followed by a norm layer of type `NormType`.\n",
    "public struct ConvNorm<NormType: Norm & FALayer>: FALayer\n",
    "    where NormType.AllDifferentiableVariables == NormType.TangentVector {\n",
    "    public var conv: FANoBiasConv2D<Float>\n",
    "    public var norm: NormType\n",
    "    \n",
    "    /// Builds the convolution from `cIn`, `cOut`, `ks` and `stride`, and the norm layer\n",
    "    /// from `cOut` with an epsilon of 1e-5.\n",
    "    public init(_ cIn: Int, _ cOut: Int, ks: Int = 3, stride: Int = 2) {\n",
    "        conv = FANoBiasConv2D(cIn, cOut, ks: ks, stride: stride, activation: relu)\n",
    "        norm = NormType(cOut, epsilon: 1e-5)\n",
    "    }\n",
    "\n",
    "    /// Applies the convolution, then the norm layer.\n",
    "    @differentiable\n",
    "    public func forward(_ input: Tensor<Float>) -> Tensor<Float> {\n",
    "        let convolved = conv(input)\n",
    "        return norm(convolved)\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "//export\n",
    "/// A stack of stride-2 `ConvNorm` blocks followed by global average pooling and a dense layer.\n",
    "public struct CnnModelNormed<NormType: Norm & FALayer>: FALayer\n",
    "    where NormType.AllDifferentiableVariables == NormType.TangentVector {\n",
    "    public var convs: [ConvNorm<NormType>]\n",
    "    public var pool = FAGlobalAvgPool2D<Float>()\n",
    "    public var linear: FADense<Float>\n",
    "    \n",
    "    /// - Parameters:\n",
    "    ///   - channelIn: channel count passed to the first conv block.\n",
    "    ///   - nOut: second argument of the final `FADense` (presumably its output size).\n",
    "    ///   - filters: channel count produced by each conv block, in order. May be empty, in\n",
    "    ///     which case no conv blocks are created and `linear` is sized from `channelIn`.\n",
    "    public init(channelIn: Int, nOut: Int, filters: [Int]) {\n",
    "        let allFilters = [channelIn] + filters\n",
    "        // Block i maps allFilters[i] channels to allFilters[i + 1] (== filters[i]) channels.\n",
    "        convs = filters.indices.map { i in\n",
    "            ConvNorm<NormType>(allFilters[i], allFilters[i + 1], ks: 3, stride: 2)\n",
    "        }\n",
    "        // `filters.last!` would trap on an empty list; fall back to the input channel count.\n",
    "        linear = FADense<Float>(filters.last ?? channelIn, nOut)\n",
    "    }\n",
    "    \n",
    "    /// Runs the conv stack, the pooling layer and the dense layer in sequence.\n",
    "    @differentiable\n",
    "    public func forward(_ input: TF) -> TF {\n",
    "        // TODO: Work around https://bugs.swift.org/browse/TF-606\n",
    "        return linear.forward(pool.forward(convs(input)))\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's benchmark this batchnorm implementation!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Times `forward` and then `backward`, printing a heading before each measurement.\n",
    "///\n",
    "/// - Parameters:\n",
    "///   - repeating: value forwarded to `time(repeating:_:)` for both measurements.\n",
    "///     Defaults to 10, the count that used to be hard-coded.\n",
    "///   - forward: closure timed under the `forward:` heading.\n",
    "///   - backward: closure timed under the `backward:` heading.\n",
    "func benchmark(repeating count: Int = 10, forward: () -> (), backward: () -> ()) {\n",
    "    print(\"forward:\")\n",
    "    time(repeating: count, forward)\n",
    "    print(\"backward:\")\n",
    "    time(repeating: count, backward)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Time the hand-written FABatchNorm on a random [64, 28, 28, 32] batch.\n",
    "let input = TF(randomUniform: [64, 28, 28, 32])\n",
    "let norm = FABatchNorm(32)\n",
    "let pb = pullback(at: input) { norm($0) }\n",
    "// `input` doubles as the cotangent handed to the pullback.\n",
    "benchmark(\n",
    "    forward: { norm(input) },\n",
    "    backward: { pb(input) })"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yikes, that's pretty bad. Luckily, TensorFlow has a built-in fused batchnorm layer. Let's see how the performance looks for that:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Time the raw fused batchnorm ops on a random [64, 28, 28, 32] batch, reusing the\n",
    "// scale, offset and epsilon of an `FABatchNorm`.\n",
    "let input = TF(randomUniform: [64, 28, 28, 32])\n",
    "let norm = FABatchNorm(32)\n",
    "// Run the forward op once up front: its reserveSpace1 / reserveSpace2 outputs are the\n",
    "// inputs the gradient op below needs. `mean` and `variance` are passed as empty tensors.\n",
    "let bnresult = Raw.fusedBatchNormV2(\n",
    "    input, scale: norm.scale, offset: norm.offset, \n",
    "    mean: TF([] as [Float]), variance: TF([] as [Float]), \n",
    "    epsilon: Double(norm.epsilon))\n",
    "// Forward repeats the call above; backward feeds `input` in as `yBackprop` as well.\n",
    "benchmark(\n",
    "    forward: {\n",
    "        Raw.fusedBatchNormV2(\n",
    "            input, scale: norm.scale, offset: norm.offset, \n",
    "            mean: TF([] as [Float]), variance: TF([] as [Float]), \n",
    "            epsilon: Double(norm.epsilon))\n",
    "    },\n",
    "    backward: {\n",
    "        Raw.fusedBatchNormGradV2(\n",
    "            yBackprop: input, input, scale: TF(norm.scale), \n",
    "            reserveSpace1: bnresult.reserveSpace1, \n",
    "            reserveSpace2: bnresult.reserveSpace2, \n",
    "            epsilon: Double(norm.epsilon))\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Bundles an input with a cotangent so both can be passed as a single `TensorGroup`.\n",
    "struct PullbackArgs<T: TensorGroup, U: TensorGroup>: TensorGroup {\n",
    "    /// Point at which the pullback is evaluated.\n",
    "    let input: T\n",
    "    /// Cotangent handed to the pullback.\n",
    "    let cotangent: U\n",
    "}\n",
    "\n",
    "/// Holds a differentiable function from `Input` to `Output`.\n",
    "class CompiledFunction<Input: Differentiable & TensorGroup, Output: Differentiable & TensorGroup> {\n",
    "    /// The wrapped function.\n",
    "    let f: @differentiable (Input) -> Output\n",
    "\n",
    "    /// Stores `function` as `f`.\n",
    "    init(_ function: @escaping @differentiable (Input) -> Output) {\n",
    "        f = function\n",
    "    }\n",
    "}\n",
    "\n",
    "/// Wraps `fn` in a `CompiledFunction` whose value and pullback are both produced by\n",
    "/// `_graph(_:useXLA:)` with `useXLA: true`.\n",
    "func xlaCompiled<T: Differentiable & TensorGroup, U: Differentiable & TensorGroup>(\n",
    "    _ fn: @escaping @differentiable (T) -> U) -> CompiledFunction<T, U>\n",
    "    where T.TangentVector: TensorGroup, U.TangentVector: TensorGroup {\n",
    "    let compiledForward: (T) -> U = _graph(fn, useXLA: true)\n",
    "    // The pullback receives its input point and cotangent packed into one `PullbackArgs` value.\n",
    "    let compiledBackward = _graph(\n",
    "        { (packed: PullbackArgs<T, U.TangentVector>) in\n",
    "            pullback(at: packed.input, in: fn)(packed.cotangent) },\n",
    "        useXLA: true\n",
    "    )\n",
    "    return CompiledFunction(differentiableFunction { point in\n",
    "        (value: compiledForward(point), pullback: { cotangent in\n",
    "            compiledBackward(PullbackArgs(input: point, cotangent: cotangent)) })\n",
    "    })\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/// Everything `trainingKernel` reads, packed into one differentiable `TensorGroup`.\n",
    "struct TrainingKernelInput: TensorGroup, Differentiable, AdditiveArithmetic {\n",
    "    var input: TF\n",
    "    var scale: TF\n",
    "    var offset: TF\n",
    "    var runningMean: TF\n",
    "    var runningVariance: TF\n",
    "    var momentum: TF\n",
    "    var epsilon: TF\n",
    "}\n",
    "\n",
    "/// Everything `trainingKernel` produces: the normalized input and both updated running statistics.\n",
    "struct TrainingKernelOutput: TensorGroup, Differentiable, AdditiveArithmetic {\n",
    "    var normalized: TF\n",
    "    var newRunningMean: TF\n",
    "    var newRunningVariance: TF\n",
    "}\n",
    "\n",
    "/// Training-mode batchnorm written as a pure function: returns the normalized input together\n",
    "/// with the updated running statistics instead of mutating any state.\n",
    "@differentiable\n",
    "func trainingKernel(_ input: TrainingKernelInput) -> TrainingKernelOutput {\n",
    "    let batchMean = input.input.mean(alongAxes: [0, 1, 2])\n",
    "    let batchVariance = input.input.variance(alongAxes: [0, 1, 2])\n",
    "    // Blend: running * momentum + batch * (1 - momentum).\n",
    "    let batchWeight = TF(1) - input.momentum\n",
    "    let updatedMean = input.runningMean * input.momentum + batchMean * batchWeight\n",
    "    let updatedVariance = input.runningVariance * input.momentum + batchVariance * batchWeight\n",
    "    let scaledInverseStd = rsqrt(batchVariance + input.epsilon) * input.scale\n",
    "    return TrainingKernelOutput(\n",
    "        normalized: (input.input - batchMean) * scaledInverseStd + input.offset,\n",
    "        newRunningMean: updatedMean,\n",
    "        newRunningVariance: updatedVariance\n",
    "    )\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Time `trainingKernel` wrapped by `xlaCompiled` on a random [64, 28, 28, 32] batch.\n",
    "let input = TF(randomUniform: [64, 28, 28, 32])\n",
    "let norm = FABatchNorm(32)\n",
    "let compiledTrainingKernel = xlaCompiled(trainingKernel)\n",
    "// Pack the batch together with the layer's parameters, running statistics and\n",
    "// hyperparameters; momentum and epsilon are wrapped in tensors to fit the struct.\n",
    "let kernelInput = TrainingKernelInput(\n",
    "    input: input,\n",
    "    scale: norm.scale,\n",
    "    offset: norm.offset,\n",
    "    runningMean: norm.runningMean.value,\n",
    "    runningVariance: norm.runningVariance.value,\n",
    "    momentum: Tensor(norm.momentum),\n",
    "    epsilon: Tensor(norm.epsilon))\n",
    "// Build the pullback and one sample output before timing; `kernelOutput` is reused as\n",
    "// the cotangent in the backward measurement.\n",
    "let pb = pullback(at: kernelInput) { x in compiledTrainingKernel.f(x) }\n",
    "let kernelOutput = compiledTrainingKernel.f(kernelInput)\n",
    "\n",
    "benchmark(\n",
    "    forward: { compiledTrainingKernel.f(kernelInput) },\n",
    "    backward: { pb(kernelOutput) })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Swift",
   "language": "swift",
   "name": "swift"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
